// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_DNN_ACCELERATOR_HLS_SYSTEMC_INPUTCONTROLLER_H_
#define THIRD_PARTY_DNN_ACCELERATOR_HLS_SYSTEMC_INPUTCONTROLLER_H_

#include <mc_connections.h>
#include <systemc.h>

#include "src/AccelTypes.h"

// InputController streams input activations from memory into two input
// buffer banks and generates the matching buffer read addresses.
//
// Four clocked threads cooperate, all configured by the same Params:
//   * read_params: pops each Params from paramsIn and forwards a copy to the
//     fetcher, writer and reader threads.
//   * fetcher: pushes one MemoryRequest (burst of NROWS) on addressRequest for
//     every input position of every tile that lies inside the image,
//     including the halo shared with neighbouring tiles.
//   * writer: pops one dataResponse per fetched position, zero-fills halo
//     positions outside the image, and pushes (address, data) pairs into the
//     currently selected buffer bank.
//   * reader: pushes buffer read addresses for the same tiles, iterating all
//     inner loops (including the weight and filter loops).
//
// Writer and reader both toggle bankSel after every iteration of inner loop 0,
// so consecutive inner tiles alternate between bank 0 and bank 1.
//
// Invariant: the writer must pop exactly as many dataResponse values as the
// fetcher pushed addressRequests (assuming the memory answers each request
// with one Pack1D -- TODO confirm against the memory model), so the writer's
// out-of-image test must mirror the fetcher's border handling.
//
// Template parameters:
//   DTYPE - element type of each Pack1D lane.
//   NROWS - lanes per Pack1D; also the burst size and channel step of each
//           memory request.
template <typename DTYPE, int NROWS>
SC_MODULE(InputController) {
  sc_in<bool> CCS_INIT_S1(clk);
  // Asynchronous reset for all threads, active while low.
  sc_in<bool> CCS_INIT_S1(rstn);

  // Configuration; each Params drives one complete pass of all loops.
  Connections::In<Params> CCS_INIT_S1(paramsIn);

  // Memory interface: address requests out, NROWS-wide data back.
  Connections::Out<MemoryRequest> CCS_INIT_S1(addressRequest);
  Connections::In<Pack1D<DTYPE, NROWS> > CCS_INIT_S1(dataResponse);

  // Buffer interfaces, one channel per bank (index = bankSel). The *Control
  // channels carry the number of writes/reads of the next inner tile.
  Connections::Out<int> bufferWriteAddress[2];
  Connections::Out<Pack1D<DTYPE, NROWS> > bufferWriteData[2];
  Connections::Out<int> bufferWriteControl[2];
  Connections::Out<int> bufferReadAddress[2];
  Connections::Out<int> bufferReadControl[2];

  // Internal channels used by read_params to distribute each Params.
  Connections::Combinational<Params> CCS_INIT_S1(fetcherParams);
  Connections::Combinational<Params> CCS_INIT_S1(writerParams);
  Connections::Combinational<Params> CCS_INIT_S1(readerParams);

  SC_CTOR(InputController) {
    // Every process is a thread clocked on clk's rising edge and reset
    // asynchronously while rstn is low.
    SC_THREAD(read_params);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);

    SC_THREAD(fetcher);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);

    SC_THREAD(reader);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);

    SC_THREAD(writer);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);
  }

  // Fetcher thread. For each Params it walks outer loops 0..2 (one tile per
  // iteration) and inner loops 0..5, and pushes one MemoryRequest per input
  // position. The weight, FX and FY inner loops are collapsed to a single
  // iteration, because the requested address does not depend on them.
  //
  // Tiles span STRIDE * loops[1][inputX] by STRIDE * loops[1][inputY] input
  // pixels. Each side that borders another tile is extended by a halo of
  // (FX - 1) / 2 columns or (FY - 1) / 2 rows. Halo on the image border is
  // not fetched; the writer zero-pads it instead.
  void fetcher() {
    addressRequest.Reset();
    fetcherParams.ResetRead();

    wait();
    while (true) {
      Params params = fetcherParams.Pop();

      int FX = params.loops[1][params.fxIndex];
      int FY = params.loops[1][params.fyIndex];

      // create array of loop counters for ability to index into counters
      // NOTE(review): only loop_counters[0][0..2] are ever assigned, so
      // inputXLoopIndex[0] / inputYLoopIndex[0] presumably must be < 3.
      int loop_counters[2][6];
      int loop_bounds[2][6];

#pragma hls_unroll yes
      for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 6; j++) {
          loop_bounds[i][j] = params.loops[i][j];
        }
      }

      // Addresses do not depend on these loops: fetch each position once.
      loop_bounds[1][params.weightLoopIndex[1]] = 1;
      loop_bounds[1][params.fxIndex] = 1;
      loop_bounds[1][params.fyIndex] = 1;

#pragma hls_pipeline_init_interval 1
#pragma hls_pipeline_stall_mode flush
      // outer memory hierarchy
      for (loop_counters[0][0] = 0; loop_counters[0][0] < loop_bounds[0][0];
           loop_counters[0][0]++) {
        for (loop_counters[0][1] = 0; loop_counters[0][1] < loop_bounds[0][1];
             loop_counters[0][1]++) {
          for (loop_counters[0][2] = 0; loop_counters[0][2] < loop_bounds[0][2];
               loop_counters[0][2]++) {
            // deal with border pixels: start from the un-padded tile size and
            // add halo only on sides that border another tile.
            loop_bounds[1][params.inputXLoopIndex[1]] =
                params.loops[1][params.inputXLoopIndex[1]] * params.STRIDE;
            loop_bounds[1][params.inputYLoopIndex[1]] =
                params.loops[1][params.inputYLoopIndex[1]] * params.STRIDE;
            int x_min_offset = 0;
            int y_min_offset = 0;
            int x_max_offset = 0;
            int y_max_offset = 0;

            // left and top border tiles
            if (loop_counters[0][params.inputXLoopIndex[0]] != 0) {
              x_min_offset = (FX - 1) / 2;
              loop_bounds[1][params.inputXLoopIndex[1]] += x_min_offset;
            }
            if (loop_counters[0][params.inputYLoopIndex[0]] != 0) {
              y_min_offset = (FY - 1) / 2;
              loop_bounds[1][params.inputYLoopIndex[1]] += y_min_offset;
            }

            // right and bottom border tiles
            if (loop_counters[0][params.inputXLoopIndex[0]] !=
                loop_bounds[0][params.inputXLoopIndex[0]] - 1) {
              x_max_offset = (FX - 1) / 2;
              loop_bounds[1][params.inputXLoopIndex[1]] += x_max_offset;
            }
            if (loop_counters[0][params.inputYLoopIndex[0]] !=
                loop_bounds[0][params.inputYLoopIndex[0]] - 1) {
              y_max_offset = (FY - 1) / 2;
              loop_bounds[1][params.inputYLoopIndex[1]] += y_max_offset;
            }

            for (loop_counters[1][0] = 0;
                 loop_counters[1][0] < loop_bounds[1][0];
                 loop_counters[1][0]++) {
              for (loop_counters[1][1] = 0;
                   loop_counters[1][1] < loop_bounds[1][1];
                   loop_counters[1][1]++) {
                for (loop_counters[1][2] = 0;
                     loop_counters[1][2] < loop_bounds[1][2];
                     loop_counters[1][2]++) {
                  for (loop_counters[1][3] = 0;
                       loop_counters[1][3] < loop_bounds[1][3];
                       loop_counters[1][3]++) {
                    for (loop_counters[1][4] = 0;
                         loop_counters[1][4] < loop_bounds[1][4];
                         loop_counters[1][4]++) {
                      for (loop_counters[1][5] = 0;
                           loop_counters[1][5] < loop_bounds[1][5];
                           loop_counters[1][5]++) {
                        // Image x: position inside the tile (minus the left
                        // halo) plus the tile's offset x1 * X0.
                        int x0 = loop_counters[1][params.inputXLoopIndex[1]];
                        int x1 = loop_counters[0][params.inputXLoopIndex[0]];
                        int X0 = params.STRIDE *
                                 params.loops[1][params.inputXLoopIndex[1]];
                        int X1 = params.STRIDE *
                                 params.loops[0][params.inputXLoopIndex[0]];
                        int x = (x0 - x_min_offset) + x1 * X0;
                        // NOTE(review): X0 and X1 both include STRIDE, so X
                        // (the memory row pitch below) is STRIDE * STRIDE *
                        // loops[1][inputX] * loops[0][inputX]. By contrast, x
                        // only spans STRIDE * loops[1][inputX] *
                        // loops[0][inputX] values. The two agree for
                        // STRIDE == 1; confirm the intended memory layout for
                        // STRIDE > 1.
                        int X = X0 * X1;

                        int y0 = loop_counters[1][params.inputYLoopIndex[1]];
                        int y1 = loop_counters[0][params.inputYLoopIndex[0]];
                        int Y0 = params.STRIDE *
                                 params.loops[1][params.inputYLoopIndex[1]];
                        int y = (y0 - y_min_offset) + y1 * Y0;

                        // Each request covers NROWS consecutive channels.
                        int c1 = loop_counters[1][params.reductionLoopIndex[1]];
                        int C1 = params.loops[1][params.reductionLoopIndex[1]];
                        int c = c1 * NROWS;
                        int C = C1 * NROWS;

                        // Row-major (y, x) layout with channels innermost.
                        int baseAddress = y * X * C + x * C + c;
                        int burstSize = NROWS;

                        if (params.CONCAT_HEAD) {
                          // Channels are grouped into heads of
                          // 2^HEAD_SZ_LG2. Each head is stored as X *
                          // HEAD_SIZE consecutive elements.
                          // NOTE(review): y does not contribute here --
                          // presumably this mode is only used with a single
                          // input row; confirm.
                          int HEAD_SIZE = (1 << params.HEAD_SZ_LG2);
                          baseAddress = ((c / HEAD_SIZE) * X * HEAD_SIZE) +
                                        (x * HEAD_SIZE) + (c % HEAD_SIZE);
                        }

                        MemoryRequest memRequest = {
                            params.INPUT_OFFSET + baseAddress, burstSize};
                        addressRequest.Push(memRequest);
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  }

  // Writer thread. Iterates the same tiles as the fetcher, but always over
  // the full padded tile. That tile is (STRIDE * loops[1][inputX] + FX - 1)
  // by (STRIDE * loops[1][inputY] + FY - 1) positions, i.e. a halo of
  // (FX - 1) / 2 and (FY - 1) / 2 on each side (FX and FY presumably odd).
  // Positions inside the image take the next dataResponse; positions
  // outside the image are written as zeros. Before each inner tile
  // (iteration of inner loop 0), the number of writes is pushed on
  // bufferWriteControl[bankSel]; the bank is toggled after the tile.
  //
  // Tile coordinates are mapped to image coordinates exactly as in the
  // fetcher: tile offset x1 * X0 and image width X0 * X1, with
  // X0 = STRIDE * loops[1][inputX] and X1 = loops[0][inputX]. STRIDE is
  // therefore applied exactly once. Applying it twice makes the writer pop a
  // different number of responses than the fetcher requested whenever
  // STRIDE > 1, and breaks the row pitch shared with the reader.
  void writer() {
    writerParams.ResetRead();
    dataResponse.Reset();

    bufferWriteControl[0].Reset();
    bufferWriteControl[1].Reset();

    bufferWriteAddress[0].Reset();
    bufferWriteAddress[1].Reset();

    bufferWriteData[0].Reset();
    bufferWriteData[1].Reset();

    wait();
    while (true) {
      Params params = writerParams.Pop();

      bool bankSel = 0;

      int FX = params.loops[1][params.fxIndex];
      int FY = params.loops[1][params.fyIndex];
      int STRIDE = params.STRIDE;

      // Halo width on each side of a tile.
      int fx_bound = (FX - 1) / 2;
      int fy_bound = (FY - 1) / 2;

      // create array of loop counters for ability to index into counters
      int loop_counters[2][6];
      int loop_bounds[2][6];

#pragma hls_unroll yes
      for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 6; j++) {
          loop_bounds[i][j] = params.loops[i][j];
        }
      }

      // Same collapsed loops as the fetcher: one write per input position.
      loop_bounds[1][params.weightLoopIndex[1]] = 1;
      loop_bounds[1][params.fxIndex] = 1;
      loop_bounds[1][params.fyIndex] = 1;

#pragma hls_pipeline_init_interval 1
#pragma hls_pipeline_stall_mode flush
      // outer memory hierarchy
      for (loop_counters[0][0] = 0; loop_counters[0][0] < loop_bounds[0][0];
           loop_counters[0][0]++) {
        for (loop_counters[0][1] = 0; loop_counters[0][1] < loop_bounds[0][1];
             loop_counters[0][1]++) {
          for (loop_counters[0][2] = 0; loop_counters[0][2] < loop_bounds[0][2];
               loop_counters[0][2]++) {
            // Always cover the fully padded tile (halo on both sides).
            loop_bounds[1][params.inputXLoopIndex[1]] =
                params.loops[1][params.inputXLoopIndex[1]] * STRIDE;
            loop_bounds[1][params.inputYLoopIndex[1]] =
                params.loops[1][params.inputYLoopIndex[1]] * STRIDE;
            int x_min_offset = fx_bound;
            int y_min_offset = fy_bound;
            loop_bounds[1][params.inputXLoopIndex[1]] += FX - 1;
            loop_bounds[1][params.inputYLoopIndex[1]] += FY - 1;

            for (loop_counters[1][0] = 0;
                 loop_counters[1][0] < loop_bounds[1][0];
                 loop_counters[1][0]++) {
              // Announce how many writes this inner tile performs.
              int total_writes =
                  (loop_bounds[1][1] * loop_bounds[1][2] * loop_bounds[1][3] *
                   loop_bounds[1][4] * loop_bounds[1][5]);
              bufferWriteControl[bankSel].Push(total_writes);

              for (loop_counters[1][1] = 0;
                   loop_counters[1][1] < loop_bounds[1][1];
                   loop_counters[1][1]++) {
                for (loop_counters[1][2] = 0;
                     loop_counters[1][2] < loop_bounds[1][2];
                     loop_counters[1][2]++) {
                  for (loop_counters[1][3] = 0;
                       loop_counters[1][3] < loop_bounds[1][3];
                       loop_counters[1][3]++) {
                    for (loop_counters[1][4] = 0;
                         loop_counters[1][4] < loop_bounds[1][4];
                         loop_counters[1][4]++) {
                      for (loop_counters[1][5] = 0;
                           loop_counters[1][5] < loop_bounds[1][5];
                           loop_counters[1][5]++) {
                        // X0: tile width in input pixels (without halo).
                        // X1: number of tiles along X. Only X0 carries the
                        // stride; it must not be applied again below.
                        int x0 = loop_counters[1][params.inputXLoopIndex[1]];
                        int x1 = loop_counters[0][params.inputXLoopIndex[0]];
                        int X0 = params.STRIDE *
                                 params.loops[1][params.inputXLoopIndex[1]];
                        int X1 = params.loops[0][params.inputXLoopIndex[0]];

                        int y0 = loop_counters[1][params.inputYLoopIndex[1]];
                        int y1 = loop_counters[0][params.inputYLoopIndex[0]];
                        int Y0 = params.STRIDE *
                                 params.loops[1][params.inputYLoopIndex[1]];
                        int Y1 = params.loops[0][params.inputYLoopIndex[0]];

                        // Image coordinates, computed as in the fetcher so
                        // both threads agree on which positions are fetched.
                        int full_x = (x0 - x_min_offset) + x1 * X0;
                        int full_y = (y0 - y_min_offset) + y1 * Y0;

                        Pack1D<DTYPE, NROWS> data;

                        // out of bounds: the fetcher issued no request for
                        // this position, so zero-pad instead of popping.
                        if ((full_x < 0) || (full_y < 0) ||
                            (full_x >= X0 * X1) ||
                            (full_y >= Y0 * Y1)) {
#pragma hls_unroll yes
                          for (int dims = 0; dims < NROWS; dims++) {
                            data[dims] = 0;
                          }
                        } else {
                          data = dataResponse.Pop();
                        }

                        // Row pitch is the padded tile width X0 + FX - 1,
                        // matching the reader's
                        // STRIDE * loops[1][inputX] + FX - 1.
                        // NOTE(review): no channel term -- presumably the
                        // reduction loop is inner loop 0, so each bank holds
                        // one channel group; confirm.
                        int address = y0 * (X0 + FX - 1) + x0;

                        bufferWriteAddress[bankSel].Push(address);
                        bufferWriteData[bankSel].Push(data);
                      }
                    }
                  }
                }
              }
              bankSel = !bankSel;
            }
          }
        }
      }
    }
  }

  // Reader thread. Before each inner tile (iteration of inner loop 0), it
  // pushes the number of reads on bufferReadControl[bankSel]. It then pushes
  // one buffer address per iteration of inner loops 1..5 and toggles the
  // bank after the tile.
  //
  // Unlike the fetcher and writer, no loop is collapsed, so the same buffer
  // address is requested once per weight-loop iteration. For output
  // position (x0, y0) and filter tap (fx, fy) the address is
  //   (STRIDE*y0 + fy) * (STRIDE*X0 + FX - 1) + (STRIDE*x0 + fx),
  // where X0 = loops[1][inputX]. The row pitch is therefore the padded tile
  // width that the writer fills.
  void reader() {
    readerParams.ResetRead();

    bufferReadControl[0].Reset();
    bufferReadControl[1].Reset();
    bufferReadAddress[0].Reset();
    bufferReadAddress[1].Reset();

    wait();
    while (true) {
      Params params = readerParams.Pop();

      bool bankSel = 0;

      int FX = params.loops[1][params.fxIndex];

      // create array of loop counters for ability to index into counters
      int loop_counters[2][6];
      int loop_bounds[2][6];

#pragma hls_unroll yes
      for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 6; j++) {
          loop_bounds[i][j] = params.loops[i][j];
        }
      }

#pragma hls_pipeline_init_interval 1
#pragma hls_pipeline_stall_mode flush
      // outer memory hierarchy
      for (loop_counters[0][0] = 0; loop_counters[0][0] < loop_bounds[0][0];
           loop_counters[0][0]++) {
        for (loop_counters[0][1] = 0; loop_counters[0][1] < loop_bounds[0][1];
             loop_counters[0][1]++) {
          for (loop_counters[0][2] = 0; loop_counters[0][2] < loop_bounds[0][2];
               loop_counters[0][2]++) {
            // inner memory hierarchy
            for (loop_counters[1][0] = 0;
                 loop_counters[1][0] < loop_bounds[1][0];
                 loop_counters[1][0]++) {
              // Announce how many reads this inner tile performs.
              int total_reads =
                  (loop_bounds[1][1] * loop_bounds[1][2] * loop_bounds[1][3] *
                   loop_bounds[1][4] * loop_bounds[1][5]);
              bufferReadControl[bankSel].Push(total_reads);
              for (loop_counters[1][1] = 0;
                   loop_counters[1][1] < loop_bounds[1][1];
                   loop_counters[1][1]++) {
                for (loop_counters[1][2] = 0;
                     loop_counters[1][2] < loop_bounds[1][2];
                     loop_counters[1][2]++) {
                  for (loop_counters[1][3] = 0;
                       loop_counters[1][3] < loop_bounds[1][3];
                       loop_counters[1][3]++) {
                    for (loop_counters[1][4] = 0;
                         loop_counters[1][4] < loop_bounds[1][4];
                         loop_counters[1][4]++) {
                      for (loop_counters[1][5] = 0;
                           loop_counters[1][5] < loop_bounds[1][5];
                           loop_counters[1][5]++) {
                        int x0 = loop_counters[1][params.inputXLoopIndex[1]];
                        int X0 = params.loops[1][params.inputXLoopIndex[1]];
                        int y0 = loop_counters[1][params.inputYLoopIndex[1]];
                        int fx = loop_counters[1][params.fxIndex];
                        int fy = loop_counters[1][params.fyIndex];

                        // Position inside the padded tile for this tap.
                        int x = params.STRIDE * x0 + fx;
                        int y = params.STRIDE * y0 + fy;

                        int address = y * (params.STRIDE * X0 + FX - 1) + x;

                        bufferReadAddress[bankSel].Push(address);
                      }
                    }
                  }
                }
              }
              bankSel = !bankSel;
            }
          }
        }
      }
    }
  }

  // Forwards every Params received on paramsIn to the fetcher, writer and
  // reader threads, in that order.
  void read_params() {
    paramsIn.Reset();
    fetcherParams.ResetWrite();
    writerParams.ResetWrite();
    readerParams.ResetWrite();

    wait();

    while (true) {
      Params params = paramsIn.Pop();

      fetcherParams.Push(params);
      writerParams.Push(params);
      readerParams.Push(params);
    }
  }
};

#endif  // THIRD_PARTY_DNN_ACCELERATOR_HLS_SYSTEMC_INPUTCONTROLLER_H_
